package org.zjvis.datascience.service;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import org.apache.poi.ss.usermodel.Workbook;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.constant.Constant;
import org.zjvis.datascience.common.constant.DatabaseConstant;
import org.zjvis.datascience.common.constant.DatasetConstant;
import org.zjvis.datascience.common.dto.DataDto;
import org.zjvis.datascience.common.dto.DatasetDTO;
import org.zjvis.datascience.common.dto.SqlCategoryDTO;
import org.zjvis.datascience.common.dto.SqlQueryDTO;
import org.zjvis.datascience.common.dto.dataset.DatasetWithCategoryDTO;
import org.zjvis.datascience.common.model.Column;
import org.zjvis.datascience.common.util.JwtUtil;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.util.StringUtil;
import org.zjvis.datascience.common.util.sqlParse.DataToFileUtil;
import org.zjvis.datascience.common.util.sqlParse.ParseUtil;
import org.zjvis.datascience.common.vo.SqlCategoryVO;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;
import org.zjvis.datascience.service.mapper.DatasetMapper;
import org.zjvis.datascience.service.mapper.SqlCategoryMapper;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Service that runs user-written SQL queries against the user's datasets, exports the results
 * (CSV / XLSX) and manages the per-user SQL query history.
 *
 * @date 2021-10-26
 */
@Service
public class SqlQueryService {

    private final static Logger logger = LoggerFactory.getLogger("SqlQueryService");

    @Autowired
    DatasetMapper datasetMapper;

    @Autowired
    GPDataProvider gpDataProvider;

    @Autowired
    SqlCategoryMapper sqlCategoryMapper;

    @Autowired
    private ServletContext servletContext;


    /**
     * Returns all datasets of the given user.
     *
     * @param userId owner of the datasets
     * @return the datasets returned by {@code DatasetMapper.queryByUserId}
     */
    public List<DatasetDTO> getTablesByUserId(Long userId) {
        return datasetMapper.queryByUserId(userId);
    }

    /**
     * Returns all datasets of the given user together with their category information.
     *
     * @param userId owner of the datasets
     * @return the datasets returned by {@code DatasetMapper.queryDetailByUserId}
     */
    public List<DatasetWithCategoryDTO> getTablesDetailByUserId(Long userId) {
        return datasetMapper.queryDetailByUserId(userId);
    }


    /**
     * Checks the current user's permission and the validity of a SQL statement.
     *
     * <p>Builds a map from the names under which the current user's datasets may be referenced
     * to their physical {@code schema.table} names. The map and the SQL are then handed to
     * {@link ParseUtil#sqlParse}, which performs the actual check. Presumably it also rewrites the
     * dataset names and rejects unknown tables; TODO confirm in ParseUtil.
     *
     * @param prefix when non-empty, datasets are keyed as {@code "categoryName::datasetName"};
     *               otherwise by the bare dataset name
     * @param sql    the SQL statement written by the user
     * @return the result of {@link ParseUtil#sqlParse}
     */
    public SqlQueryDTO checkSql(String prefix, String sql) {
        Long userId = JwtUtil.getCurrentUserId();
        boolean qualifyByCategory = prefix != null && !prefix.isEmpty();
        // Local, single-threaded map: no need for the synchronized (and null-hostile) Hashtable.
        Map<String, String> tablesMap = new HashMap<>();
        for (DatasetWithCategoryDTO dataset : getTablesDetailByUserId(userId)) {
            // fastjson returns null for null/empty input; such a dataset cannot be mapped.
            JSONObject dataJson = JSONObject.parseObject(dataset.getDataJson());
            if (dataJson == null) {
                logger.warn("dataset {} has no data json, skipped when checking sql", dataset.getName());
                continue;
            }
            String table = dataJson.getString("schema") + "."
                    + SqlUtil.formatPGSqlColName(dataJson.getString("table"));
            String key = qualifyByCategory
                    ? dataset.getCategoryName() + "::" + dataset.getName()
                    : dataset.getName();
            tablesMap.put(key, table);
        }
        return ParseUtil.sqlParse(prefix, sql, tablesMap);
    }

    /**
     * Exports the result of a SQL query as a CSV attachment.
     *
     * <p>If {@code sql} is null or blank, a one-row placeholder file is written. If the query
     * fails, the method returns without writing a body; the download headers have already been
     * set at that point.
     *
     * @param response servlet response the CSV is streamed into
     * @param sql      SQL query statement
     */
    public void exportToCsv(HttpServletResponse response, String sql) {
        response.setContentType("application/force-download");
        response.setHeader("Content-Disposition", "attachment;filename=" + DatabaseConstant.GP_SQL_EXPORT_NAME + ".csv");
        Optional<DataDto> data = loadExportData(sql);
        if (!data.isPresent()) {
            return;
        }
        try {
            DataToFileUtil.DataToCsvStream(data.get(), response.getOutputStream());
        } catch (IOException | HttpMessageNotWritableException e) {
            logger.error("something wrong when export data as CSV file, since {}", e.getMessage(), e);
        }
    }

    /**
     * Exports the result of a SQL query as an XLSX attachment.
     *
     * <p>If {@code sql} is null or blank, a one-row placeholder sheet is written. If the query
     * fails, the response is left untouched.
     *
     * @param response servlet response the workbook is written into
     * @param sql      SQL query statement
     */
    public void exportToXlsx(HttpServletResponse response, String sql) {
        Optional<DataDto> data = loadExportData(sql);
        if (!data.isPresent()) {
            return;
        }
        // Re-encode the UTF-8 bytes as ISO-8859-1 so the raw bytes survive the header encoding.
        String fileName = new String(
                (DatabaseConstant.GP_SQL_EXPORT_NAME + ".xlsx").getBytes(StandardCharsets.UTF_8),
                StandardCharsets.ISO_8859_1);
        // try-with-resources so the workbook's resources are released even if writing fails.
        try (Workbook workbook = DataToFileUtil.DataToXlsx(data.get())) {
            response.reset();
            response.setHeader("Content-Disposition", "attachment;filename=" + fileName);
            response.setHeader("Content-type", "application/octet-stream");
            response.setCharacterEncoding("UTF-8");
            OutputStream outputStream = response.getOutputStream();
            workbook.write(outputStream);
            outputStream.flush();
            outputStream.close();
        } catch (IOException e) {
            logger.error("something wrong when export data as XLSX file, since {}", e.getMessage(), e);
        }
    }

    /**
     * Loads the data to be exported for {@code sql}; shared by the CSV and XLSX exports.
     *
     * @param sql SQL query statement; null or blank yields a placeholder data set
     * @return the query data, or empty when the query did not succeed (code other than 200)
     *         or returned no data
     */
    private Optional<DataDto> loadExportData(String sql) {
        if (sql == null || sql.trim().isEmpty()) {
            return Optional.of(DataDto.builder()
                    .head(ImmutableMap.of("header", "content"))
                    .data(Lists.newArrayList(ImmutableMap.of("header", "sql query result is empty")))
                    .build());
        }
        SqlQueryDTO sqlQueryDTO = queryDataBySql(sql, DatabaseConstant.GP_SQL_EXPORT_COUNT, false);
        if (sqlQueryDTO.getCode() != 200) {
            logger.warn("skip exporting sql query result, query finished with code {}", sqlQueryDTO.getCode());
            return Optional.empty();
        }
        return Optional.ofNullable(sqlQueryDTO.getData());
    }

    /**
     * Runs a SQL query whose table names are resolved with the given category prefix. The query is
     * NOT recorded in the user's history.
     *
     * @param sql      SQL query statement
     * @param category prefix passed to {@link #checkSql(String, String)}
     * @param limit    maximum number of rows to return
     * @return the query result
     */
    public SqlQueryDTO queryDataBySql(String sql, String category, int limit) {
        return queryDataBySql(sql, category, limit, false);
    }

    /**
     * Runs a SQL query whose table names are resolved with the given category prefix.
     *
     * @param sql        SQL query statement
     * @param category   prefix passed to {@link #checkSql(String, String)}; when non-empty,
     *                   datasets are referenced as {@code "categoryName::datasetName"}
     * @param limit      maximum number of rows to return
     * @param needUpdate whether a successful query is recorded in the user's SQL history
     * @return the check result if the check reports code 400, otherwise the execution result
     */
    public SqlQueryDTO queryDataBySql(String sql, String category, int limit, boolean needUpdate) {
        return runQuery(category, sql, limit, needUpdate);
    }

    /**
     * Runs a SQL query and records it in the user's SQL history when it succeeds.
     *
     * @param sql   SQL query statement
     * @param limit maximum number of rows to return
     * @return the query result
     */
    public SqlQueryDTO queryDataBySql(String sql, int limit) {
        return queryDataBySql(sql, limit, true);
    }

    /**
     * Runs a SQL query that references datasets by their bare names.
     *
     * <p>Chinese quotes are removed from {@code sql} (via {@link StringUtil#removeChineseQuote})
     * before it is checked and recorded.
     *
     * @param sql        SQL query statement
     * @param limit      maximum number of rows to return
     * @param needUpdate whether a successful query is recorded in the user's SQL history
     * @return the check result if the check reports code 400, otherwise the execution result
     */
    public SqlQueryDTO queryDataBySql(String sql, int limit, boolean needUpdate) {
        return runQuery("", StringUtil.removeChineseQuote(sql), limit, needUpdate);
    }

    /**
     * Shared implementation of the {@code queryDataBySql} overloads: check, limit, execute and
     * optionally record the statement.
     *
     * @param prefix     prefix passed to {@link #checkSql(String, String)}
     * @param sql        SQL statement as written by the user (this text is what gets recorded)
     * @param limit      maximum number of rows to return
     * @param needUpdate whether a successful query is recorded in the user's SQL history
     * @return the check result if the check reports code 400, otherwise the execution result
     */
    private SqlQueryDTO runQuery(String prefix, String sql, int limit, boolean needUpdate) {
        Long userId = JwtUtil.getCurrentUserId();
        SqlCategoryDTO history = new SqlCategoryDTO(userId, sql);
        SqlQueryDTO checked = checkSql(prefix, sql);
        // NOTE(review): only code 400 is treated as a failed check; confirm ParseUtil.sqlParse
        // never reports other error codes.
        if (checked.getCode() == 400) {
            return checked;
        }
        String limitedSql = ParseUtil.addLimit(checked.getSql(), limit);
        logger.info("going to execute sql = {}", limitedSql);
        SqlQueryDTO result = executeSql(limitedSql);
        if (result.getCode() == 200 && needUpdate) {
            sqlCategoryMapper.create(history);
        }
        return result;
    }


    /**
     * Executes a SQL query on the default data source and packs the result into a {@link DataDto}.
     *
     * <p>The column named {@link DatasetConstant#DEFAULT_ID_FIELD} is left out of the head.
     *
     * @param sql the (already checked) SQL query statement
     * @return code 200 with the data on success, or code 500 with the exception message on failure
     */
    public SqlQueryDTO executeSql(String sql) {
        String dataSourceKey = (String) servletContext.getAttribute(Constant.DEFAULT_DATA_SOURCE_KEY);
        try {
            List<Column> columnList = gpDataProvider.queryTableMeta(dataSourceKey, sql);
            JSONArray queryResult = gpDataProvider.executeQuerySQL(dataSourceKey, sql, JSONArray.class);
            // Build the head: column name -> mapped type name, in column order.
            Map<String, String> heads = new LinkedHashMap<>();
            for (Column col : columnList) {
                if (DatasetConstant.DEFAULT_ID_FIELD.equals(col.getName())) {
                    continue;
                }
                heads.put(col.getName(), SqlUtil.changeType(col.getTypeName()));
            }
            // NOTE(review): unlike the head, rows still contain DEFAULT_ID_FIELD; confirm this is
            // intended by the consumers of DataDto.
            List<Map> rows = queryResult.toJavaList(Map.class);
            return new SqlQueryDTO(200, new DataDto(heads, rows));
        } catch (Exception e) {
            logger.error("failed to execute sql = {}, since {}", sql, e.getMessage(), e);
            return new SqlQueryDTO(500, e.getMessage());
        }
    }

    /**
     * Returns the SQL query history of the given user.
     *
     * @param userId owner of the history
     * @return the entries returned by {@code SqlCategoryMapper.queryByUserId}
     */
    public List<SqlCategoryDTO> getSqlCategory(Long userId) {
        return sqlCategoryMapper.queryByUserId(userId);
    }

    /**
     * Deletes the whole SQL query history of the given user.
     *
     * @param userId owner of the history
     * @return the value returned by {@code SqlCategoryMapper.deleteAll}
     */
    public int deleteAllSqlCategory(Long userId) {
        return sqlCategoryMapper.deleteAll(userId);
    }

    /**
     * Deletes the selected SQL history entries of the given user.
     *
     * @param sqlCategoryVOS entries to delete; {@code null} is treated as an empty selection
     * @param userId         owner of the entries
     * @return ids of the entries whose deletion reported failure (empty when all succeeded)
     */
    public List<Long> deleteSqlCategory(List<SqlCategoryVO> sqlCategoryVOS, Long userId) {
        List<Long> errors = new ArrayList<>();
        if (sqlCategoryVOS == null) {
            return errors;
        }
        for (SqlCategoryVO vo : sqlCategoryVOS) {
            if (!sqlCategoryMapper.delete(vo.getId(), userId)) {
                errors.add(vo.getId());
            }
        }
        return errors;
    }
}
